# -*- coding: utf-8 -*-

import random
import secrets
import typing
from typing import Tuple

import Cryptodome
from Cryptodome.Cipher import DES, DES3, AES
from Cryptodome.Hash import SHA256, SHA1, MD5
from Cryptodome.Cipher import PKCS1_v1_5
from Cryptodome.Util import number
from Cryptodome.Signature import pkcs1_15
from Cryptodome.PublicKey import RSA
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
from gmssl.sm3 import sm3_hash
from gmssl.sm2 import CryptSM2
from gmssl import func
from gmssl import sm2_key_gen


# ============================== 非对称 ================================
class SM2:
    """
    SM2 encryption, decryption, signing and verification built on gmssl's CryptSM2.

    Keys are raw bytes: a PublicKeySize (64) byte public key, which is hashed into
    Z as x || y, and a private key of any other length (PrivateKeySize, 32, expected).
    """
    PublicKeySize = 64
    PrivateKeySize = 32
    RandomSize = 32
    SignSize = 64
    ENTL = bytes.fromhex("0080")
    ID = bytes.fromhex("31323334353637383132333435363738")
    # Curve coefficients a, b and base point coordinates xG, yG, hashed into Z by getZ.
    A = bytes.fromhex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC")
    B = bytes.fromhex("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93")
    XG = bytes.fromhex("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7")
    YG = bytes.fromhex("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0")

    def __init__(self, keyValue: bytes, otherKeyValue=None) -> None:
        """
        Classify each key by its length: PublicKeySize bytes is the public key, any
        other length is the private key.

        Encryption and verification need only the public key; decryption needs only
        the private key. Signing needs both (the public key enters Z), so pass the
        second key as otherKeyValue.
        """
        self.publicKey = None
        self.privateKey = None
        self._setKey(keyValue)
        if otherKeyValue is not None:
            self._setKey(otherKeyValue)

    def _setKey(self, key: bytes) -> None:
        """Store key as the public key if it is PublicKeySize bytes, else as the private key."""
        if len(key) == self.PublicKeySize:
            self.publicKey = key
        else:
            self.privateKey = key

    def _needPublicKey(self) -> bytes:
        """Return the public key, or raise a clear error when none was supplied."""
        if self.publicKey is None:
            raise Exception('缺少SM2公钥')
        return self.publicKey

    def _needPrivateKey(self) -> bytes:
        """Return the private key, or raise a clear error when none was supplied."""
        if self.privateKey is None:
            raise Exception('缺少SM2私钥')
        return self.privateKey

    @classmethod
    def random(cls) -> bytes:
        """Return PrivateKeySize random bytes, suitable as the encryption/signing random k."""
        return randomBytes(cls.PrivateKeySize)

    @classmethod
    def genKey(cls) -> Tuple[bytes, bytes]:
        """Generate a key pair and return (publicKey, privateKey); the public key is uncompressed."""
        priKey = sm2_key_gen.PrivateKey()
        pubKey = priKey.publicKey()
        return (bytes.fromhex(pubKey.toString(compressed=False)), bytes.fromhex(priKey.toString()))

    def encrypt(self, data: bytes, *, random: bytes = None, randomHook=None) -> bytes:
        """
        Encrypt data with the public key.

        random: the random k to use; a fresh one from self.random() is generated when None.
        randomHook: optional callback, def randomHook(random: bytes) -> None, that receives
            the random actually used (e.g. to reproduce a test vector).
        """
        cipher = CryptSM2(public_key=self._needPublicKey().hex(), private_key=None)
        if random is None:
            random = self.random()
        if randomHook is not None:
            randomHook(random)
        return cipher.encrypt(data, random.hex())

    def decrypt(self, data: bytes) -> bytes:
        """Decrypt data with the private key."""
        cipher = CryptSM2(public_key=None, private_key=self._needPrivateKey().hex())
        return cipher.decrypt(data)

    def sign(self, message: bytes, *, ida: bytes = b'1234567812345678', random: bytes = None, zHook=None, randomHook=None, digestHook=None) -> bytes:
        """
        Sign message, including the SM2 pre-processing step (digest = SM3(Z || M)).

        zHook / digestHook / randomHook: optional callbacks receiving Z, the digest, and
        the random k respectively, e.g. def digestHook(digest: bytes) -> None.
        """
        digest = self.pre(message, ida=ida, zHook=zHook)
        if digestHook is not None:
            digestHook(digest)
        if random is None:
            random = self.random()
        if randomHook is not None:
            randomHook(random)
        return self.signDigest(digest, random=random)

    def signDigest(self, digest: bytes, *, random: bytes = None) -> bytes:
        """
        Sign an already pre-processed digest (no Z computation here).

        random: the signing random k; generated with self.random() when None.
        """
        # The default of None used to crash on random.hex(); generate k instead.
        if random is None:
            random = self.random()
        cipher = CryptSM2(public_key=None, private_key=self._needPrivateKey().hex())
        return bytes.fromhex(cipher.sign(digest, random.hex()))

    def verify(self, sign: bytes, message: bytes, *, ida: bytes = b'1234567812345678', zHook=None, digestHook=None) -> bool:
        """
        Verify a signature over message, including the SM2 pre-processing step.

        zHook / digestHook: optional callbacks receiving Z and the digest.
        """
        digest = self.pre(message, ida=ida, zHook=zHook)
        if digestHook is not None:
            digestHook(digest)
        return self.verifyDigest(sign, digest)

    def verifyDigest(self, sign: bytes, digest: bytes) -> bool:
        """Verify a signature over an already pre-processed digest."""
        cipher = CryptSM2(public_key=self._needPublicKey().hex(), private_key=None)
        return cipher.verify(sign.hex(), digest)

    def getZ(self, ida: bytes = b'1234567812345678') -> bytes:
        """
        Z = SM3(ENTL || ID || a || b || xG || yG || publicKey).

        ENTL is the bit length of ID as 2 big-endian bytes. ida=None falls back to self.ID.
        """
        ida = ida if ida is not None else self.ID
        entl = (len(ida) * 8).to_bytes(2, 'big')
        return SM3(entl + ida + self.A + self.B + self.XG + self.YG + self._needPublicKey()).digest()

    def pre(self, message: bytes, *, ida: bytes = b'1234567812345678', zHook=None) -> bytes:
        """
        Signature pre-processing:
            Z      = SM3(ENTL || ID || a || b || xG || yG || x || y)
            digest = SM3(Z || M)
        """
        z = self.getZ(ida)
        if zHook is not None:
            zHook(z)
        return SM3(z + message).digest()


class RSA:
    """
    Raw RSA on big-endian byte strings, with optional PKCS#1 v1.5 padding
    (block type 2 for encryption, block type 1 for signatures).

    All key material is bytes except the public exponent e, which is an int.
    Ciphertexts and signatures are returned at the full modulus length k = len(n);
    raw values recovered with pad=False have their leading zero bytes stripped.
    """
    DefaultE = 65537

    # DER DigestInfo prefixes (AlgorithmIdentifier + OCTET STRING header), keyed by
    # digest length: 16 = MD5, 20 = SHA-1, 32 = SHA-256. The first prefix of each
    # tuple is emitted when signing; every listed prefix is accepted when verifying
    # (SHA-1/SHA-256 also without the NULL parameter, as some signers omit it).
    _DIGEST_INFO_PREFIXES = {
        16: (bytes.fromhex('3020300C06082A864886F70D020505000410'),),
        20: (bytes.fromhex('3021300906052B0E03021A05000414'),
             bytes.fromhex('301F300706052B0E03021A0414')),
        32: (bytes.fromhex('3031300D060960864801650304020105000420'),
             bytes.fromhex('302F300B06096086480165030402010420')),
    }

    def __init__(self, n: bytes, e: int = 65537, *, d=None, pq=None) -> None:
        """
        n: modulus as big-endian bytes; its length fixes the key size.
        e: public exponent (None means DefaultE).
        d: private exponent as big-endian bytes, optional.
        pq: p || q concatenated, split at len(n) // 2 (the layout genKey produces).
            When given, dp/dq/qinv are derived, p*q is checked against n, and d is
            derived if it was not supplied.

        Raises Exception('公私钥不匹配') when p*q != n.
        """
        self.n = n
        self.e = e if e is not None else self.DefaultE
        self.d = d
        self.pq = pq
        self.p = None
        self.q = None
        self.dp = None
        self.dq = None
        self.qinv = None
        self.modulusSize = len(n) * 8
        if pq is None:
            return
        half = len(n) // 2
        self.p = pq[:half]
        self.q = pq[half:]
        (calcN, _), (calcD, _, _, dp, dq, qinv) = self.calcKey(self.p, self.q, self.e)
        if calcN != int.from_bytes(n, byteorder='big'):
            raise Exception('公私钥不匹配')
        if self.d is None:
            self.d = calcD.to_bytes(len(n), 'big')
        # dp < p, qinv < p and dq < q, so each fits in the byte length of its prime.
        self.dp = dp.to_bytes(len(self.p), 'big')
        self.dq = dq.to_bytes(len(self.q), 'big')
        self.qinv = qinv.to_bytes(len(self.p), 'big')

    @classmethod
    def genKey(cls, modulusSize: int = 2048, e: int = 65537):
        """
        Generate a new key pair with PyCryptodome.

        Returns ((n, e), (d, p, q, dp, dq, qinv)). e is an int; everything else is
        big-endian bytes (n and d are modulusSize/8 bytes, the rest half of that).
        """
        e = e if e is not None else cls.DefaultE
        key = Cryptodome.PublicKey.RSA.generate(modulusSize, e=e)
        size = modulusSize // 8
        half = size // 2
        dp = key.d % (key.p - 1)
        dq = key.d % (key.q - 1)
        qinv = number.inverse(key.q, key.p)
        n = key.n.to_bytes(size, 'big')
        d = key.d.to_bytes(size, 'big')
        p = key.p.to_bytes(half, 'big')
        q = key.q.to_bytes(half, 'big')
        dp = dp.to_bytes(half, 'big')
        dq = dq.to_bytes(half, 'big')
        qinv = qinv.to_bytes(half, 'big')
        return ((n, e), (d, p, q, dp, dq, qinv))

    @classmethod
    def calcKey(cls, p: bytes, q: bytes, e: int = 65537):
        """
        Derive the full key from the primes p and q (big-endian bytes).

        Returns ((n, e), (d, p, q, dp, dq, qinv)) with every element as an int.
        """
        p = int.from_bytes(p, byteorder='big')
        q = int.from_bytes(q, byteorder='big')
        e = e if e is not None else cls.DefaultE
        phi = (p - 1) * (q - 1)
        lcm = phi // number.GCD((p - 1), (q - 1))
        # d = e^-1 mod lcm(p-1, q-1) matches Cryptodome's RsaKey.d. The inverse modulo
        # phi gives a different but also valid exponent (the RsaTool "Calculate D" value).
        d = number.inverse(e, lcm)
        dp = d % (p - 1)
        dq = d % (q - 1)
        qinv = number.inverse(q, p)
        n = p * q
        return ((n, e), (d, p, q, dp, dq, qinv))

    def encrypt(self, data: bytes, *, pad: bool = True, usePrivateKey: bool = False) -> bytes:
        """
        Encrypt data; the result is always k = len(n) bytes.

        usePrivateKey: raw private-key operation; pad is ignored in that case.
        pad: apply PKCS#1 v1.5 encryption padding (type 2) before the public-key operation.
        """
        if usePrivateKey:
            return self._fullLength(self._sk(data))
        if pad:
            data = self._padEncrypt(data)
        return self._fullLength(self._pk(data))

    def decrypt(self, data: bytes, *, pad: bool = True, usePublicKey: bool = False) -> bytes:
        """
        Decrypt data.

        usePublicKey: raw public-key operation; pad is ignored in that case.
        pad: remove PKCS#1 v1.5 encryption padding (type 2) after the private-key operation.
        """
        if usePublicKey:
            return self._pk(data)
        # The old debug print leaked the decrypted (padded) plaintext to stdout.
        data = self._sk(data)
        if pad:
            data = self._unpadEncrypt(data)
        return data

    def sign(self, message: bytes, hash, *, pad: bool = True, digestHook=None) -> bytes:
        """
        Hash message and sign the digest.

        hash: the hash class used for the digest, e.g. SHA1 / SHA256 / MD5 from this module.
        digestHook: optional callback, def digestHook(digest: bytes) -> None.
        """
        digest = hash(message).digest()
        if digestHook is not None:
            digestHook(digest)
        return self.signDigest(digest, pad=pad)

    def signDigest(self, digest: bytes, *, pad: bool = True) -> bytes:
        """
        Sign a precomputed digest; the signature is k = len(n) bytes.

        With pad=True the DigestInfo OID is chosen from the digest length
        (16: MD5, 20: SHA-1, 32: SHA-256).
        """
        data = self._padSign(digest) if pad else digest
        return self._fullLength(self._sk(data))

    def verify(self, sign: bytes, message: bytes, *, pad: bool = True, digestHook=None) -> bool:
        """
        Recover the digest from sign and compare it with the hash of message.

        The hash algorithm is inferred from the recovered digest's length.
        digestHook: optional callback receiving the recovered digest.
        """
        digest = self._pk(sign)
        if pad:
            digest = self._unpadSign(digest)
        if digestHook is not None:
            digestHook(digest)

        hash = self._getHashByDigest(digest)
        return hash(message).digest() == digest

    def verifyDigest(self, sign: bytes, digest: bytes, *, pad: bool = True) -> bool:
        """Recover the digest from sign (removing padding when pad=True) and compare it with digest."""
        data = self._pk(sign)
        if pad:
            data = self._unpadSign(data)
        return data == digest

    def _cutZero(self, data: bytes) -> bytes:
        """Strip leading 0x00 bytes; an all-zero (or empty) input is returned unchanged."""
        stripped = data.lstrip(b'\x00')
        return stripped if stripped else data

    def _fullLength(self, data: bytes) -> bytes:
        """Left-pad with 0x00 to k = len(n) bytes, as PKCS#1 requires for ciphertexts and signatures."""
        return data.rjust(self.modulusSize // 8, b'\x00')

    def _modPow(self, data: bytes, exponent: int) -> bytes:
        """
        Compute data^exponent mod n on big-endian bytes, with leading zero bytes stripped.

        Raises Exception when the input is 0 or not smaller than n.
        """
        n = int.from_bytes(self.n, byteorder='big')
        value = int.from_bytes(data, byteorder='big')
        if value >= n or value == 0:
            raise Exception('数据应小于N且不为0\n')
        result = pow(value, exponent, n)
        return self._cutZero(result.to_bytes(self.modulusSize // 8, 'big'))

    def _pk(self, data: bytes) -> bytes:
        """Public-key operation: data^e mod n."""
        return self._modPow(data, self.e)

    def _sk(self, data: bytes) -> bytes:
        """Private-key operation: data^d mod n."""
        if self.d is None:
            raise Exception('缺少私钥d')
        return self._modPow(data, int.from_bytes(self.d, byteorder='big'))

    def _padEncrypt(self, data: bytes) -> bytes:
        """
        PKCS#1 v1.5 encryption padding (block type 2):
            EM = 0x00 || 0x02 || PS || 0x00 || M
        PS is k - mLen - 3 non-zero random bytes, at least 8 of them, so M may be
        at most k - 11 bytes.
        """
        k = self.modulusSize // 8
        if len(data) > k - 11:
            raise Exception(f'自动填充模式下，明文数据最长为 {k - 11}，现长 {len(data)}')
        ps = randomBytesNoZero(k - len(data) - 3)
        return b'\x00\x02' + ps + b'\x00' + data

    def _unpadEncrypt(self, data: bytes) -> bytes:
        """
        Remove PKCS#1 v1.5 encryption padding.

        The leading 0x00 may already be stripped. The message starts after the FIRST
        0x00 following the 0x02 marker (PS is non-zero), so messages containing 0x00
        are preserved intact.
        """
        if data[:1] == b'\x00':
            data = data[1:]
        sep = data.find(b'\x00', 1)
        # sep < 9 also covers "not found" (-1) and a PS shorter than 8 bytes.
        if data[:1] != b'\x02' or sep < 9:
            raise Exception('错误的填充格式')
        return data[sep + 1:]

    def _digestInfoPrefixes(self, digest: bytes) -> tuple:
        """Return the accepted DigestInfo prefixes for the digest's length."""
        prefixes = self._DIGEST_INFO_PREFIXES.get(len(digest))
        if prefixes is None:
            raise Exception('不支持的HASH算法')
        return prefixes

    def _padSign(self, digest: bytes) -> bytes:
        """
        PKCS#1 v1.5 signature padding (block type 1):
            T  = DigestInfo prefix || H
            EM = 0x00 || 0x01 || PS || 0x00 || T,  PS = (k - tLen - 3) bytes of 0xFF
        """
        t = self._digestInfoPrefixes(digest)[0] + digest
        k = self.modulusSize // 8
        psLen = k - len(t) - 3
        if psLen < 8:
            raise Exception(f'模长过短，无法进行签名填充（至少需要 {len(t) + 11} 字节）')
        return b'\x00\x01' + b'\xff' * psLen + b'\x00' + t

    def _unpadSign(self, data: bytes) -> bytes:
        """
        Remove PKCS#1 v1.5 signature padding and return the embedded digest.

        Checks the 0x01 marker, an all-0xFF PS of at least 8 bytes, the 0x00
        separator, and that T is exactly a known DigestInfo prefix plus a digest of
        the matching length. The leading 0x00 may already be stripped.
        """
        if data[:1] == b'\x00':
            data = data[1:]
        sep = data.find(b'\x00', 1)
        if data[:1] != b'\x01' or sep < 9 or data[1:sep] != b'\xff' * (sep - 1):
            raise Exception('错误的填充格式')
        t = data[sep + 1:]
        for size, prefixes in self._DIGEST_INFO_PREFIXES.items():
            for prefix in prefixes:
                if len(t) == len(prefix) + size and t.startswith(prefix):
                    return t[len(prefix):]
        raise Exception('不支持的HASH算法')

    def _getHashByDigest(self, digest):
        """Map a digest length to this module's hash class (MD5 / SHA1 / SHA256)."""
        if len(digest) == Hash.MD5Size:
            return MD5
        elif len(digest) == Hash.SHA1Size:
            return SHA1
        elif len(digest) == Hash.SHA256Size:
            return SHA256
        raise Exception('不支持的HASH算法')

# ============================== HASH ================================
class Hash:
    """
    Base class for the digest wrappers in this module.

    The message is given at construction. Subclasses either set self.cipher to an
    object whose digest() returns the hash, or override digest() themselves.
    """
    MD5Size = 16
    SHA1Size = 20
    SHA256Size = 32
    SM3Size = 32

    def __init__(self, data: bytes) -> None:
        self.data = data
        self.cipher = None

    def digest(self) -> bytes:
        """Return the digest of self.data; raises Exception if no algorithm was assigned."""
        if self.cipher is None:
            raise Exception('算法未指定')
        return self.cipher.digest()

class SM3(Hash):
    """SM3 digest (Hash.SM3Size bytes) of the data given at construction, computed with gmssl."""

    Size = Hash.SM3Size

    def digest(self) -> bytes:
        # sm3_hash takes a list of byte values and yields hex text, which is turned back into bytes.
        byteList = func.bytes_to_list(self.data)
        hexDigest = sm3_hash(byteList)
        return bytes.fromhex(hexDigest)

class MD5(Hash):
    """MD5 digest (Hash.MD5Size bytes) of the data given at construction, via PyCryptodome."""

    Size = Hash.MD5Size

    def __init__(self, data: bytes) -> None:
        super().__init__(data)
        # Fully qualified: the bare name MD5 is rebound to this class in this module.
        backend = Cryptodome.Hash.MD5
        self.cipher = backend.new(self.data)

class SHA1(Hash):
    """SHA-1 digest (Hash.SHA1Size bytes) of the data given at construction, via PyCryptodome."""

    Size = Hash.SHA1Size

    def __init__(self, data: bytes) -> None:
        super().__init__(data)
        # Fully qualified: the bare name SHA1 is rebound to this class in this module.
        backend = Cryptodome.Hash.SHA1
        self.cipher = backend.new(self.data)

class SHA256(Hash):
    """SHA-256 digest (Hash.SHA256Size bytes) of the data given at construction, via PyCryptodome."""

    Size = Hash.SHA256Size

    def __init__(self, data: bytes) -> None:
        super().__init__(data)
        # Fully qualified: the bare name SHA256 is rebound to this class in this module.
        backend = Cryptodome.Hash.SHA256
        self.cipher = backend.new(self.data)



# ============================== 对称 ================================
class Symmetric:
    """
    Base class for the block-cipher wrappers in this module.

    Subclasses set keySize/blockSize, call _init_check(), and assign self.cipher: an
    object with encrypt(bytes) and decrypt(bytes). The wrappers use ECB when iv is
    None and CBC otherwise.
    """
    DESKeySize = 8
    DES3KeySize = 16
    DES33KeySize = 24
    AESKeySize = 16
    SM4KeySize = 16
    DESBlockSize = 8
    DES3BlockSize = 8
    DES33BlockSize = 8
    AESBlockSize = 16
    SM4BlockSize = 16

    def __init__(self, keyValue: bytes, *, iv=None) -> None:
        self.key = keyValue
        self.iv = iv
        self.cipher = None
        self.keySize = 0
        self.blockSize = 0

    def _init_check(self):
        """Validate the IV and key lengths against the sizes set by the subclass."""
        if self.iv is not None and len(self.iv) != self.blockSize:
            raise Exception(f'IV长应为{self.blockSize}')
        if self.key is None or len(self.key) != self.keySize:
            raise Exception(f'Key长度应为{self.keySize}')

    def _check(self, data, pad):
        """Reject a missing algorithm, empty data, or unaligned data when no padding is applied."""
        if self.cipher is None:
            raise Exception('未指定算法')
        if data is None or len(data) == 0:
            raise Exception('未指定数据')
        if not pad and len(data) % self.blockSize != 0:
            raise Exception(f'数据长应为{self.blockSize}的整数倍')

    def encrypt(self, data: bytes, *, pad=False) -> bytes:
        """
        Encrypt data (CBC when an IV was given, otherwise ECB).

        pad=True appends 0x80 and then 0x00 bytes up to a multiple of the block size;
        at least one byte is always added.
        """
        self._check(data, pad)
        if pad:
            data = self._pad(data)
        return self.cipher.encrypt(data)

    def decrypt(self, data: bytes, *, pad=False) -> bytes:
        """
        Decrypt data (CBC when an IV was given, otherwise ECB).

        pad=True removes the 0x80 0x00... padding added by encrypt; returns None
        when that padding is not present.
        """
        self._check(data, pad)
        data = self.cipher.decrypt(data)
        if pad:
            data = self._unpad(data)
        return data

    def _pad(self, data: bytes) -> bytes:
        """Append 0x80 then 0x00 bytes so the length becomes the next multiple of blockSize."""
        padLen = self.blockSize - len(data) % self.blockSize
        return bytes(data) + b'\x80' + b'\x00' * (padLen - 1)

    def _unpad(self, data: bytes) -> bytes:
        """
        Strip trailing 0x00 bytes and the 0x80 marker; return None if the marker is missing.

        Unlike a backward search for any 0x80, this also handles a padding-only
        block (0x80 at index 0). It also refuses data whose last non-zero byte is
        not 0x80, instead of cutting inside the plaintext.
        """
        stripped = bytes(data).rstrip(b'\x00')
        if stripped.endswith(b'\x80'):
            return stripped[:-1]
        return None

class DES(Symmetric):
    """Single DES (8-byte key, 8-byte block): ECB when iv is None, otherwise CBC."""
    KeySize = Symmetric.DESKeySize
    BlockSize = Symmetric.DESBlockSize

    def __init__(self, keyValue: bytes, *, iv=None) -> None:
        super().__init__(keyValue, iv=iv)
        self.keySize = self.KeySize
        self.blockSize = self.BlockSize
        self._init_check()
        self.cipher = self.Cipher(self.key, iv=iv)

    class Cipher:
        """
        Per-call adapter: builds a fresh PyCryptodome DES object for every operation.

        A PyCryptodome CBC object carries chaining state and refuses decrypt() after
        encrypt(). Reusing one object would make repeated calls chain from the
        previous ciphertext and break encrypt-then-decrypt on the same instance.
        """

        def __init__(self, keyValue: bytes, *, iv=None) -> None:
            self.key = keyValue
            self.iv = iv

        def _new(self):
            algo = Cryptodome.Cipher.DES
            if self.iv is None:
                return algo.new(self.key, algo.MODE_ECB)
            return algo.new(self.key, algo.MODE_CBC, iv=self.iv)

        def encrypt(self, data: bytes) -> bytes:
            return self._new().encrypt(data)

        def decrypt(self, data: bytes) -> bytes:
            return self._new().decrypt(data)

class DES3(Symmetric):
    """
    Two-key Triple DES: a 16-byte key K1||K2 used as K1||K2||K1, with an 8-byte block.
    ECB when iv is None, otherwise CBC.
    """
    KeySize = Symmetric.DES3KeySize
    BlockSize = Symmetric.DES3BlockSize

    def __init__(self, keyValue: bytes, *, iv=None) -> None:
        super().__init__(keyValue, iv=iv)
        self.keySize = self.KeySize
        self.blockSize = self.BlockSize
        self._init_check()
        # Expand K1||K2 to the 24-byte K1||K2||K1 form (after validating the 16-byte input).
        self.key = self.key + self.key[:8]
        self.cipher = self.Cipher(self.key, iv=iv)

    class Cipher:
        """
        Per-call adapter: builds a fresh PyCryptodome DES3 object for every operation.

        A PyCryptodome CBC object carries chaining state and refuses decrypt() after
        encrypt(). Reusing one object would make repeated calls chain from the
        previous ciphertext and break encrypt-then-decrypt on the same instance.
        """

        def __init__(self, keyValue: bytes, *, iv=None) -> None:
            self.key = keyValue
            self.iv = iv

        def _new(self):
            algo = Cryptodome.Cipher.DES3
            if self.iv is None:
                return algo.new(self.key, algo.MODE_ECB)
            return algo.new(self.key, algo.MODE_CBC, iv=self.iv)

        def encrypt(self, data: bytes) -> bytes:
            return self._new().encrypt(data)

        def decrypt(self, data: bytes) -> bytes:
            return self._new().decrypt(data)

class DES33(Symmetric):
    """Three-key Triple DES (24-byte key, 8-byte block): ECB when iv is None, otherwise CBC."""
    KeySize = Symmetric.DES33KeySize
    BlockSize = Symmetric.DES33BlockSize

    def __init__(self, keyValue: bytes, *, iv=None) -> None:
        super().__init__(keyValue, iv=iv)
        self.keySize = self.KeySize
        self.blockSize = self.BlockSize
        self._init_check()
        self.cipher = self.Cipher(self.key, iv=iv)

    class Cipher:
        """
        Per-call adapter: builds a fresh PyCryptodome DES3 object for every operation.

        A PyCryptodome CBC object carries chaining state and refuses decrypt() after
        encrypt(). Reusing one object would make repeated calls chain from the
        previous ciphertext and break encrypt-then-decrypt on the same instance.
        """

        def __init__(self, keyValue: bytes, *, iv=None) -> None:
            self.key = keyValue
            self.iv = iv

        def _new(self):
            algo = Cryptodome.Cipher.DES3
            if self.iv is None:
                return algo.new(self.key, algo.MODE_ECB)
            return algo.new(self.key, algo.MODE_CBC, iv=self.iv)

        def encrypt(self, data: bytes) -> bytes:
            return self._new().encrypt(data)

        def decrypt(self, data: bytes) -> bytes:
            return self._new().decrypt(data)
    

class AES(Symmetric):
    """AES-128 (16-byte key, 16-byte block): ECB when iv is None, otherwise CBC."""
    KeySize = Symmetric.AESKeySize
    BlockSize = Symmetric.AESBlockSize

    def __init__(self, keyValue: bytes, *, iv=None) -> None:
        super().__init__(keyValue, iv=iv)
        self.keySize = self.KeySize
        self.blockSize = self.BlockSize
        self._init_check()
        self.cipher = self.Cipher(self.key, iv=iv)

    class Cipher:
        """
        Per-call adapter: builds a fresh PyCryptodome AES object for every operation.

        A PyCryptodome CBC object carries chaining state and refuses decrypt() after
        encrypt(). Reusing one object would make repeated calls chain from the
        previous ciphertext and break encrypt-then-decrypt on the same instance.
        """

        def __init__(self, keyValue: bytes, *, iv=None) -> None:
            self.key = keyValue
            self.iv = iv

        def _new(self):
            algo = Cryptodome.Cipher.AES
            if self.iv is None:
                return algo.new(self.key, algo.MODE_ECB)
            return algo.new(self.key, algo.MODE_CBC, iv=self.iv)

        def encrypt(self, data: bytes) -> bytes:
            return self._new().encrypt(data)

        def decrypt(self, data: bytes) -> bytes:
            return self._new().decrypt(data)
   

class SM4(Symmetric):
    """
    SM4 block cipher (16-byte key, 16-byte block) backed by gmssl's CryptSM4.
    ECB when iv is None, otherwise CBC with the given IV.
    """
    KeySize = Symmetric.SM4KeySize
    BlockSize = Symmetric.SM4BlockSize

    def __init__(self, keyValue: bytes, *, iv=None) -> None:
        super().__init__(keyValue, iv=iv)
        self.keySize, self.blockSize = self.KeySize, self.BlockSize
        self._init_check()
        self.cipher = self.Cipher(self.key, iv=iv)

    class Cipher:
        """
        Adapter exposing encrypt/decrypt over CryptSM4 for the Symmetric base class.

        set_key is called with the direction on every call, and CBC is always
        invoked with the stored IV.
        NOTE(review): depending on the gmssl version, crypt_ecb/crypt_cbc may add
        or strip PKCS#7 padding themselves; confirm this matches Symmetric's
        pad=False contract.
        """

        def __init__(self, keyValue: bytes, *, iv=None) -> None:
            self.sm4 = CryptSM4()
            self.key = keyValue
            self.iv = iv

        def _run(self, direction, data):
            self.sm4.set_key(self.key, direction)
            if self.iv is not None:
                return self.sm4.crypt_cbc(self.iv, data)
            return self.sm4.crypt_ecb(data)

        def encrypt(self, data):
            return self._run(SM4_ENCRYPT, data)

        def decrypt(self, data):
            return self._run(SM4_DECRYPT, data)



# ============================== 随机数 ================================
def randomBytes(size: int) -> bytes:
    """
    Return `size` cryptographically secure random bytes.

    Used for SM2 random k and key-sized values, so it must not come from the
    seedable, predictable `random` module: a guessable signing nonce reveals the
    private key.
    """
    return secrets.token_bytes(size)

def randomBytesNoZero(size: int) -> bytes:
    """
    Return `size` cryptographically secure random bytes, each in 1..255.

    Used for the non-zero PS string of PKCS#1 v1.5 encryption padding, which must
    be unpredictable; hence `secrets` rather than the `random` module.
    """
    return bytes(secrets.randbelow(255) + 1 for _ in range(size))